#! /usr/bin/env python

############################################################################
# Copyright (c) 2016 Dr. Peter Bunting, Aberystwyth University
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#
# Purpose:  Calculate an error (confusion) matrix and accuracy statistics
#           (user, producer and overall accuracy, kappa) from CSV files,
#           exporting the results as CSV and optionally LaTeX / PDF.
#
# Author: Pete Bunting
# Email: pfb@aber.ac.uk
# Date: 14/04/16
# Version: 1.0
#
# History:
# Version 1.0 - Created.
#
#############################################################################

import argparse
import csv
import os
import subprocess

import numpy

def readCSVColumns(csvFilePath, imgCol, grdCol, imgColVals, grdColVals):
    """
    Read two named columns from a CSV file and append their values, row by
    row, to the lists provided. The first row of the file is used as the
    column names (csv.DictReader).

    :param csvFilePath: path of the CSV file to read.
    :param imgCol: name of the column holding the image (classified) class.
    :param grdCol: name of the column holding the ground truth class.
    :param imgColVals: list which the imgCol values are appended to (modified in place).
    :param grdColVals: list which the grdCol values are appended to (modified in place).

    A KeyError is raised if either column name is not present in the file.
    """
    # newline='' is required by the csv module so that it does its own
    # newline handling; without it newlines inside quoted fields are altered.
    with open(csvFilePath, newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            imgColVals.append(row[imgCol])
            grdColVals.append(row[grdCol])

def exportCSVFile(outFile, errMatrix, producerAcc, userAcc, producerAccCount, userAccCount, overallAcc, overallCorrCount, kappa, uniqueClasses):
    """
    Write the accuracy statistics to a CSV file. The file holds the overall
    accuracy and kappa, followed by the error matrix as counts and then as
    percentages of the total count.

    :param outFile: path of the CSV file to write.
    :param errMatrix: 2D numpy array indexed [i, j] in the order of uniqueClasses.
    :param producerAcc: per class producer accuracy (%), written as the last row of the percentage table.
    :param userAcc: per class user accuracy (%), written as the last column of the percentage table.
    :param producerAccCount: per class value written in the 'Producer' row of the counts table.
    :param userAccCount: per class value written in the 'User' column of the counts table.
    :param overallAcc: overall accuracy (%).
    :param overallCorrCount: value written in the bottom right cell of the counts table.
    :param kappa: kappa coefficient.
    :param uniqueClasses: sequence of class names, defining the row / column order.
    """
    numClasses = len(uniqueClasses)
    classNames = [uniqueClasses[i] for i in range(numClasses)]

    # newline='' stops the text layer translating the '\r\n' the csv writer
    # emits (which would otherwise become '\r\r\n' on some platforms).
    with open(outFile, 'wt', newline='') as csvfile:
        accWriter = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)

        accWriter.writerow(['Overall Accuracy (%)', round(overallAcc, 2)])
        accWriter.writerow(['kappa', round(kappa, 2)])

        # Error matrix as counts.
        accWriter.writerow([])
        accWriter.writerow(['Counts:'])
        accWriter.writerow([' '] + classNames + ['User'])
        for i, className in enumerate(classNames):
            row = [className]
            row.extend(errMatrix[i, j] for j in range(numClasses))
            row.append(round(userAccCount[i], 2))
            accWriter.writerow(row)
        prodRow = ['Producer']
        prodRow.extend(round(producerAccCount[i], 2) for i in range(numClasses))
        prodRow.append(overallCorrCount)
        accWriter.writerow(prodRow)

        # Error matrix as a percentage of the total number of samples.
        accWriter.writerow([])
        accWriter.writerow(['Percentage:'])
        errMatrixPercent = (errMatrix / numpy.sum(errMatrix)) * 100
        accWriter.writerow([' '] + classNames + ['User (%)'])
        for i, className in enumerate(classNames):
            row = [className]
            row.extend(round(errMatrixPercent[i, j], 2) for j in range(numClasses))
            row.append(round(userAcc[i], 2))
            accWriter.writerow(row)
        prodRow = ['Producer (%)']
        prodRow.extend(round(producerAcc[i], 2) for i in range(numClasses))
        prodRow.append(round(overallAcc, 2))
        accWriter.writerow(prodRow)


def exportToTex(outFile, errMatrix, producerAcc, userAcc, overallAcc, kappa, uniqueClasses):
    """
    Write the error matrix as a standalone LaTeX document containing a single
    table. The diagonal cells are highlighted, the last column holds the user
    accuracies and the last row the producer accuracies.

    :param outFile: path of the .tex file to write.
    :param errMatrix: 2D numpy array indexed [row, column] in the order of uniqueClasses.
    :param producerAcc: per class producer accuracy (%).
    :param userAcc: per class user accuracy (%).
    :param overallAcc: overall accuracy (%), used in the caption and bottom right cell.
    :param kappa: kappa coefficient, used in the caption.
    :param uniqueClasses: sequence of class names; this list is not modified.
    """
    # Escape underscores for LaTeX on a copy so the caller's list is untouched.
    texClasses = [str(classVal).replace("_", "\\_") for classVal in uniqueClasses]
    numClasses = len(texClasses)

    # LaTeX header and start of the table
    lines = [
        '\\documentclass[12pt]{amsart}\n',
        '\\usepackage{geometry}\n',
        '\\geometry{a4paper}\n',
        '\\geometry{landscape}\n',
        '\\usepackage{graphicx}\n',
        '\\usepackage{amssymb}\n',
        '\\usepackage{epstopdf}\n',
        '\\usepackage{colortbl}\n',
        '\\usepackage[table]{xcolor}\n',
        '\\begin{document}\n',
        '\\begin{table}[htbp]\n',
        '\\caption{Error Matrix. Overall Accuracy: ' + str(round(overallAcc, 2)) + '\\% Kappa: ' + str(round(kappa, 3)) + '}\n',
        '\\begin{center}\n',
        # One column for the row titles, one per class and one for the user accuracy
        '\\begin{tabular}{|c|' + 'c|' * numClasses + 'c|}\n',
        '\\hline\n',
    ]

    # Table column titles (the first cell is left empty)
    titleLine = ''.join('&\\cellcolor{gray!50}\\textbf{' + className + '}' for className in texClasses)
    lines.append(titleLine + '&\\cellcolor{gray!50}\\textbf{User}\\\\\n')
    lines.append('\\hline\n')

    # Main table body, one row per class
    for yIdx, className in enumerate(texClasses):
        cells = ['\\cellcolor{gray!50}\\textbf{' + className + '}']
        for xIdx in range(numClasses):
            cellVal = str(errMatrix[yIdx, xIdx])
            if xIdx == yIdx:
                # Highlight the diagonal (correctly classified) cells
                cellVal = '\\cellcolor{red!50}' + cellVal
            cells.append(cellVal)
        cells.append('\\textbf{' + str(round(userAcc[yIdx], 2)) + '\\%}')
        lines.append('&'.join(cells) + '\\\\\n')
        lines.append('\\hline\n')

    # Bottom (producers) row
    cells = ['\\cellcolor{gray!50}\\textbf{Prod}']
    cells.extend('\\textbf{' + str(round(val, 2)) + '\\%}' for val in producerAcc)
    cells.append('\\textbf{' + str(round(overallAcc, 2)) + '\\%}')
    lines.append('&'.join(cells) + '\\\\\n')
    lines.append('\\hline\n')

    # End the table and document
    lines.extend([
        '\\end{tabular}\n',
        '\\end{center}\n',
        '\\label{tab:ErrMatrix}\n',
        '\\end{table}\n',
        '\\end{document}\n',
    ])

    # The context manager closes the file even if the write fails
    with open(outFile, 'w') as outputFile:
        outputFile.writelines(lines)

def buildMatrix(imgColVals, grdColVals, uniqueElems):
    """
    Build the error matrix and derive the accuracy statistics from it.

    The matrix is indexed [image class, ground class], with both axes in the
    order given by uniqueElems.

    :param imgColVals: sequence of image (classified) class values, one per sample.
    :param grdColVals: sequence of ground truth class values, one per sample.
    :param uniqueElems: list of the class values; defines the matrix ordering.

    :return: tuple of (errMatrix, producerAcc, userAcc, producerAccCount,
             userAccCount, overallAcc, overallCorrCount, kappa) where:
             errMatrix is the matrix of sample counts;
             producerAcc / userAcc are the per class diagonal count divided by
             the column / row total, as percentages (0 where the total is 0);
             producerAccCount / userAccCount are the column / row totals;
             overallAcc is the diagonal total as a percentage of all samples;
             overallCorrCount is the diagonal total;
             kappa is the kappa coefficient, NaN when its denominator is zero
             (e.g., every sample falls in a single matrix cell).

    A ValueError is raised if a class value is not in uniqueElems.
    """
    numClasses = len(uniqueElems)
    # numpy.float was removed in NumPy 1.24; the builtin float is the same dtype.
    errMatrix = numpy.zeros((numClasses, numClasses), dtype=float)

    # Map each class to its first position in uniqueElems so that each sample
    # is a dictionary lookup rather than a scan of the list.
    classIdxLut = {}
    for idx, classVal in enumerate(uniqueElems):
        classIdxLut.setdefault(classVal, idx)

    for i, imgClass in enumerate(imgColVals):
        try:
            imgClassIdx = classIdxLut[imgClass]
            grdClassIdx = classIdxLut[grdColVals[i]]
        except KeyError as err:
            # Same exception type as list.index raises for a missing value.
            raise ValueError(repr(err.args[0]) + " is not in uniqueElems")
        errMatrix[imgClassIdx, grdClassIdx] += 1

    # All initialised to zero, which is the value used when a row / column is empty.
    producerAcc = numpy.zeros(numClasses, dtype=float)
    userAcc = numpy.zeros(numClasses, dtype=float)
    producerAccCount = numpy.zeros(numClasses, dtype=float)
    userAccCount = numpy.zeros(numClasses, dtype=float)
    overallCorrCount = 0.0

    for i in range(numClasses):
        corVal = float(errMatrix[i, i])
        sumRow = float(numpy.sum(errMatrix[i, :]))
        sumCol = float(numpy.sum(errMatrix[:, i]))
        overallCorrCount += corVal

        userAccCount[i] = sumRow
        producerAccCount[i] = sumCol
        if sumRow != 0:
            userAcc[i] = corVal / sumRow
        if sumCol != 0:
            producerAcc[i] = corVal / sumCol

    overallAcc = (overallCorrCount / numpy.sum(errMatrix)) * 100
    producerAcc = producerAcc * 100
    userAcc = userAcc * 100

    # kappa = (N * correct - sum(row * col)) / (N^2 - sum(row * col))
    kappaPartA = overallCorrCount * numpy.sum(producerAccCount)
    kappaPartB = numpy.sum(userAccCount * producerAccCount)
    kappaPartC = numpy.sum(errMatrix) * numpy.sum(errMatrix)

    kappaDenom = float(kappaPartC - kappaPartB)
    if kappaDenom == 0:
        # Kappa is undefined here; return NaN rather than raising ZeroDivisionError.
        kappa = float('nan')
    else:
        kappa = float(kappaPartA - kappaPartB) / kappaDenom

    return errMatrix, producerAcc, userAcc, producerAccCount, userAccCount, overallAcc, overallCorrCount, kappa

def produceErrorMatrix(inputCSVs, outputCSVFile, outputTEXFile, outTex, outPDF, imgCol, grdCol):
    """
    Read the class columns from the input CSV files, calculate the error
    matrix and accuracy statistics and write the outputs.

    :param inputCSVs: list of input CSV file paths; the values from all files are pooled.
    :param outputCSVFile: path of the output CSV file (always written).
    :param outputTEXFile: path of the output LaTeX file (only used if outTex is True).
    :param outTex: if True the LaTeX file is written.
    :param outPDF: if True (and outTex is True) pdflatex is run on the LaTeX file.
    :param imgCol: name of the CSV column holding the image (classified) class.
    :param grdCol: name of the CSV column holding the ground truth class.
    """
    imgColVals = []
    grdColVals = []
    for csvFile in inputCSVs:
        readCSVColumns(csvFile, imgCol, grdCol, imgColVals, grdColVals)

    # Sorted so that the class order in the outputs is the same on every run;
    # the iteration order of a set of strings varies between interpreter runs.
    # key=str keeps the sort valid if a value is not a string (e.g., None).
    uniqueClasses = sorted(set(imgColVals + grdColVals), key=str)
    print(uniqueClasses)

    (errMatrix, producerAcc, userAcc, producerAccCount, userAccCount,
     overallAcc, overallCorrCount, kappa) = buildMatrix(imgColVals, grdColVals, uniqueClasses)

    exportCSVFile(outputCSVFile, errMatrix, producerAcc, userAcc, producerAccCount, userAccCount, overallAcc, overallCorrCount, kappa, uniqueClasses)

    if outTex:
        exportToTex(outputTEXFile, errMatrix, producerAcc, userAcc, overallAcc, kappa, uniqueClasses)
        if outPDF:
            # Pass the arguments as a list (no shell) so that paths containing
            # spaces or shell characters are handled correctly.
            try:
                subprocess.call(['pdflatex', outputTEXFile])
            except OSError as err:
                # A missing pdflatex is not fatal; the other outputs are already written.
                print("Could not run pdflatex: " + str(err))

if __name__ == '__main__':
    # The command line user interface
    parser = argparse.ArgumentParser(prog='rsgiscalcaccmatrix.py', description='''Calculate Error Matrix''')

    parser.add_argument("-c", "--csv", type=str, required=True, help='''Output CSV File.''')
    parser.add_argument("-t", "--tex", type=str, help='''Output TEX File.''')
    parser.add_argument("-p", "--pdf", action='store_true', default=False, help='''Output PDF File.''')
    # nargs='+' so that at least one input file has to be given; with '*' an
    # empty list was accepted and the calculation then failed.
    parser.add_argument("-i", "--inputs", type=str, required=True, nargs='+', help='''Input CSV files''')
    parser.add_argument("-r", "--imgcol", type=str, required=True, help='''Image Column Name''')
    parser.add_argument("-g", "--grdcol", type=str, required=True, help='''Ground Truth Column Name ''')

    # Call the parser to parse the arguments.
    args = parser.parse_args()

    # The LaTeX output is only produced when a TEX file path was provided.
    outTex = args.tex is not None

    produceErrorMatrix(args.inputs, args.csv, args.tex, outTex, args.pdf, args.imgcol, args.grdcol)
    
    